import json
import logging
import os
from concurrent.futures import Executor
from typing import Iterator, Optional

import requests
from cachetools import TTLCache, cached

from dbgpt.core import MessageConverter, ModelOutput, ModelRequest, ModelRequestContext
from dbgpt.model.parameter import ProxyModelParameters
from dbgpt.model.proxy.base import ProxyLLMClient
from dbgpt.model.proxy.llms.proxy_model import ProxyModel

# https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
MODEL_VERSION_MAPPING = {
    "ERNIE-Bot-4.0": "completions_pro",
    "ERNIE-Bot-8K": "ernie_bot_8k",
    "ERNIE-Bot": "completions",
    "ERNIE-Bot-turbo": "eb-instant",
}

_DEFAULT_MODEL = "ERNIE-Bot"

logger = logging.getLogger(__name__)


# Cache the token for 30 minutes; maxsize=1 means only the most recently
# requested (api_key, secret_key) pair is kept.
@cached(TTLCache(maxsize=1, ttl=1800))
def _build_access_token(api_key: str, secret_key: str) -> str:
    """Generate an access token from the API key (AK) and secret key (SK)."""

    url = "https://aip.baidubce.com/oauth/2.0/token"
    params = {
        "grant_type": "client_credentials",
        "client_id": api_key,
        "client_secret": secret_key,
    }

    res = requests.get(url=url, params=params, timeout=60)

    if res.status_code != 200:
        raise RuntimeError(
            f"Failed to get access token, status code: {res.status_code}"
        )
    return res.json().get("access_token")
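
# A successful token response from the endpoint above looks roughly like the
# following; only "access_token" is consumed here, and the exact fields are
# illustrative:
#
#   {"access_token": "24.xxxx...", "expires_in": 2592000}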


def _to_wenxin_messages(request: ModelRequest):
    """Convert messages to wenxin compatible format

    See https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
    """
    messages, system_messages = request.split_messages()
    if len(system_messages) > 1:
        raise ValueError("Wenxin only supports a single system message")
    str_system_message = system_messages[0] if system_messages else ""
    return messages, str_system_message
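
# A minimal sketch of the two pieces returned above. The message dict shape is
# an assumption based on ModelRequest.split_messages; see the Baidu doc linked
# above for the authoritative wire format:
#
#   history = [
#       {"role": "user", "content": "Hi"},
#       {"role": "assistant", "content": "Hello! How can I help?"},
#   ]
#   str_system_message = "You are a helpful assistant."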


def wenxin_generate_stream(
    model: ProxyModel, tokenizer, params, device, context_len=2048
):
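    """Stream chat completions from Wenxin; the generate-stream entrypoint used
    by the proxy model adapter."""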
    client: WenxinLLMClient = model.proxy_llm_client
    context = ModelRequestContext(stream=True, user_name=params.get("user_name"))
    request = ModelRequest.build_request(
        client.default_model,
        messages=params["messages"],
        temperature=params.get("temperature"),
        context=context,
        max_new_tokens=params.get("max_new_tokens"),
    )
    for r in client.sync_generate_stream(request):
        yield r


class WenxinLLMClient(ProxyLLMClient):
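    """Client for Baidu's Wenxin (ERNIE Bot) chat completion API."""
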
    def __init__(
        self,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        api_secret: Optional[str] = None,
        model_version: Optional[str] = None,
        model_alias: Optional[str] = "wenxin_proxyllm",
        context_length: Optional[int] = 8192,
        executor: Optional[Executor] = None,
    ):
        if not model:
            model = _DEFAULT_MODEL
        if not api_key:
            api_key = os.getenv("WEN_XIN_API_KEY")
        if not api_secret:
            api_secret = os.getenv("WEN_XIN_API_SECRET")
        if not model_version:
            # `model` is always set at this point, so fall back to the
            # environment variable when the mapping has no entry for it.
            model_version = MODEL_VERSION_MAPPING.get(model) or os.getenv(
                "WEN_XIN_MODEL_VERSION"
            )
        if not api_key:
            raise ValueError("api_key can't be empty")
        if not api_secret:
            raise ValueError("api_secret can't be empty")
        if not model_version:
            raise ValueError("model_version can't be empty")
        self._model = model
        self._api_key = api_key
        self._api_secret = api_secret
        self._model_version = model_version

        super().__init__(
            model_names=[model, model_alias],
            context_length=context_length,
            executor=executor,
        )

    @classmethod
    def new_client(
        cls,
        model_params: ProxyModelParameters,
        default_executor: Optional[Executor] = None,
    ) -> "WenxinLLMClient":
        return cls(
            model=model_params.proxyllm_backend,
            api_key=model_params.proxy_api_key,
            api_secret=model_params.proxy_api_secret,
            model_version=model_params.proxy_api_version,
            model_alias=model_params.model_name,
            context_length=model_params.max_context_size,
            executor=default_executor,
        )

    @property
    def default_model(self) -> str:
        return self._model

    def sync_generate_stream(
        self,
        request: ModelRequest,
        message_converter: Optional[MessageConverter] = None,
    ) -> Iterator[ModelOutput]:
        request = self.local_covert_message(request, message_converter)

        try:
            access_token = _build_access_token(self._api_key, self._api_secret)
            if not access_token:
                raise RuntimeError(
                    "Failed to get access token. Please check that api_key and "
                    "api_secret are correct."
                )

            headers = {"Content-Type": "application/json", "Accept": "application/json"}

            proxy_server_url = (
                "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
                f"{self._model_version}?access_token={access_token}"
            )

            history, system_message = _to_wenxin_messages(request)
            payload = {
                "messages": history,
                "system": system_message,
                "temperature": request.temperature,
                "stream": True,
            }

            text = ""
            res = requests.post(
                proxy_server_url, headers=headers, json=payload, stream=True
            )
            logger.info(
                f"Send request to {proxy_server_url} with real model {self._model}, model version {self._model_version}"
            )
            for line in res.iter_lines():
                if not line:
                    continue
                if not line.startswith(b"data: "):
                    # Not an SSE data line: the server returned an error body
                    # instead of a stream event.
                    error_message = line.decode("utf-8")
                    yield ModelOutput(text=error_message, error_code=1)
                    continue
                json_data = line.split(b": ", 1)[1]
                decoded_line = json_data.decode("utf-8")
                if decoded_line.lower() != "[done]":
                    obj = json.loads(decoded_line)
                    if obj.get("result") is not None:
                        text += obj["result"]
                yield ModelOutput(text=text, error_code=0)
        except Exception as e:
            logger.error(f"Error generating stream from Wenxin: {e}")
            # A bare `return value` inside a generator would silently drop the
            # error output, so yield it instead.
            yield ModelOutput(
                text=f"**LLMServer Generate Error, Please CheckErrorInfo.**: {e}",
                error_code=1,
            )

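
if __name__ == "__main__":
    # Minimal usage sketch, assuming WEN_XIN_API_KEY and WEN_XIN_API_SECRET are
    # set in the environment (alternatively pass api_key/api_secret directly).
    # The message dict shape below is an assumption; adapt it to however your
    # caller builds ModelRequest messages.
    client = WenxinLLMClient()
    request = ModelRequest.build_request(
        client.default_model,
        messages=[{"role": "human", "content": "Hello, who are you?"}],
        temperature=0.7,
    )
    for output in client.sync_generate_stream(request):
        print(output.text)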